package http

import (
	"bytes"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"gitee.com/llakcs/agile-go/encoding"
	"gitee.com/llakcs/agile-go/middleware"
	"gitee.com/llakcs/agile-go/registry"
	"gitee.com/llakcs/agile-go/selector"
	"gitee.com/llakcs/agile-go/transport/target"
	"gitee.com/llakcs/agile-go/utils"
	"io"
	"net"
	"net/http"
	"strings"
	"time"
)

// init installs the builder returned by selector.NewWeightedRoundRobinBuilder
// as the global selector builder, but only when no global selector has been
// set yet, so an earlier registration is left untouched.
func init() {
	if selector.GlobalSelector() != nil {
		return
	}
	selector.SetGlobalSelector(selector.NewWeightedRoundRobinBuilder())
}

// Common media types, e.g. for the contentType argument of Client.Invoke.

// Structured-data media types.
const (
	MediaTypeJSON = "application/json"
	MediaTypeXML  = "application/xml"
)

// Form-submission media types.
const (
	MediaTypeFormUrlEncoded    = "application/x-www-form-urlencoded"
	MediaTypeMultipartFormData = "multipart/form-data"
)

// Text media types.
const (
	MediaTypePlainText = "text/plain"
	MediaTypeHTML      = "text/html"
)

// clientOptions holds the settings that ClientOption values apply in NewClient.
type clientOptions struct {
	// ctx is the context passed to NewClient.
	ctx context.Context
	// tlsConf, when non-nil, switches the client to https and is installed
	// on the transport; when nil the client uses plain http.
	tlsConf *tls.Config
	// timeout is the overall timeout of the underlying http.Client.
	timeout time.Duration
	// endpoint is the target address (e.g. ip:port), parsed by target.ParseTarget.
	endpoint string
	// encoder encodes request arguments in Invoke.
	encoder EncodeRequestFunc
	// decoder decodes successful responses in Invoke.
	decoder DecodeResponseFunc
	// discovery, when set, enables node resolution for discovery targets.
	discovery registry.Discovery
	// middleware wraps every Invoke call.
	middleware []middleware.Middleware
	// transport is the round tripper used by the underlying http.Client.
	transport *http.Transport
	// block is forwarded to NewResolver.
	block bool
}

type (
	// ClientOption mutates the client settings; NewClient applies options in order.
	ClientOption func(*clientOptions)

	// DecodeErrorFunc turns a response into an error.
	// NOTE(review): Invoke builds its own status error and does not call this type.
	DecodeErrorFunc func(ctx context.Context, res *http.Response) error

	// EncodeRequestFunc encodes in into a request body for the given content type.
	EncodeRequestFunc func(ctx context.Context, contentType string, in interface{}) (body []byte, err error)

	// DecodeResponseFunc decodes the body of res into out.
	DecodeResponseFunc func(ctx context.Context, res *http.Response, out interface{}) error
)

// WithTransport replaces the default *http.Transport created by NewClient.
func WithTransport(t *http.Transport) ClientOption {
	setTransport := func(o *clientOptions) { o.transport = t }
	return setTransport
}

// WithEncodeRequestFunc sets the function Invoke uses to encode request
// arguments. The default is DefaultRequestEncoder.
func WithEncodeRequestFunc(e EncodeRequestFunc) ClientOption {
	return ClientOption(func(o *clientOptions) { o.encoder = e })
}

// WithDecodeResponseFunc sets the function Invoke uses to decode successful
// (2xx) responses. The default is DefaultResponseDecoder.
func WithDecodeResponseFunc(d DecodeResponseFunc) ClientOption {
	setDecoder := func(o *clientOptions) { o.decoder = d }
	return setDecoder
}

// WithTimeout sets the overall request timeout of the underlying http.Client.
// NewClient defaults it to 5s; per net/http, zero means no timeout.
func WithTimeout(d time.Duration) ClientOption {
	return ClientOption(func(o *clientOptions) { o.timeout = d })
}

// WithMiddleware sets the middleware chain that wraps every Invoke call.
//
// The slice is copied so that later changes to the caller's slice (for
// example one passed as mws...) cannot alter the client's chain. A later
// WithMiddleware option replaces, rather than extends, an earlier one.
func WithMiddleware(m ...middleware.Middleware) ClientOption {
	chain := append([]middleware.Middleware(nil), m...)
	return func(o *clientOptions) {
		o.middleware = chain
	}
}

// WithEndpoint sets the target address. NewClient parses it with
// target.ParseTarget; combined with WithDiscovery it may use the
// "discovery" scheme, otherwise it must be a host:port address.
func WithEndpoint(endpoint string) ClientOption {
	setEndpoint := func(o *clientOptions) { o.endpoint = endpoint }
	return setEndpoint
}

// WithDiscovery sets the service discovery used to resolve "discovery"
// scheme endpoints to concrete nodes.
func WithDiscovery(d registry.Discovery) ClientOption {
	return ClientOption(func(o *clientOptions) { o.discovery = d })
}

// WithTLSConfig makes the client use https with the given TLS config.
// Without it the client sends plain http requests.
func WithTLSConfig(c *tls.Config) ClientOption {
	setTLS := func(o *clientOptions) { o.tlsConf = c }
	return setTLS
}

// WithBlock sets the block flag that NewClient forwards to NewResolver.
// NOTE(review): presumably makes resolver creation wait for the first
// discovery result — confirm in Resolver.
func WithBlock() ClientOption {
	return ClientOption(func(o *clientOptions) { o.block = true })
}

// Client sends HTTP requests to a single target, optionally resolving that
// target to concrete nodes through service discovery.
type Client struct {
	// r resolves discovery targets to nodes; nil when discovery is not used.
	r *Resolver
	// opts holds the options the client was built with.
	opts *clientOptions
	// target is the parsed endpoint; its Scheme and Authority form request URLs.
	target *target.Target
	// client performs the HTTP round trips.
	client *http.Client
	// insecure is true when no TLS config was given, selecting plain http.
	insecure bool
}

// NewClient creates a Client configured by opts.
//
// Without options the client uses a pooled *http.Transport (settings below),
// a 5s overall request timeout, DefaultRequestEncoder and
// DefaultResponseDecoder. The endpoint is parsed with target.ParseTarget.
//
// WithTLSConfig switches requests to https and installs the config on the
// transport in use. That includes a transport supplied through WithTransport,
// which is therefore modified in place. Without a TLS config, requests use
// plain http.
//
// When a discovery is configured and the endpoint scheme is "discovery", a
// Resolver is created so that each request is routed to a node picked by a
// selector built from selector.GlobalSelector(). When a discovery is
// configured but the endpoint is not a discovery target, the endpoint must
// contain a valid host and port.
func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
	tr := &http.Transport{
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second, // idle connection timeout
		TLSHandshakeTimeout:   10 * time.Second, // TLS handshake timeout
		ExpectContinueTimeout: 1 * time.Second,  // "Expect: 100-continue" timeout
		DialContext: (&net.Dialer{
			Timeout:   30 * time.Second, // dial timeout
			KeepAlive: 30 * time.Second, // interval between TCP keep-alive probes
		}).DialContext,
		ForceAttemptHTTP2: true,
	}
	options := &clientOptions{
		ctx:       ctx,
		timeout:   5000 * time.Millisecond,
		transport: tr,
		encoder:   DefaultRequestEncoder,
		decoder:   DefaultResponseDecoder,
	}
	for _, o := range opts {
		o(options)
	}
	// WithTransport(nil) would otherwise put a typed-nil *http.Transport into
	// the http.Client's RoundTripper interface, which panics on first use.
	if options.transport == nil {
		options.transport = tr
	}
	// Install the TLS config on the transport actually used. Setting it only
	// on the default transport silently dropped it whenever WithTransport was
	// also given.
	if options.tlsConf != nil {
		options.transport.TLSClientConfig = options.tlsConf
	}
	insecure := options.tlsConf == nil
	tgt, err := target.ParseTarget(options.endpoint, insecure, false)
	if err != nil {
		return nil, err
	}
	var r *Resolver
	if options.discovery != nil {
		if tgt.Scheme == "discovery" {
			// Node selector used by the resolver; only needed for discovery targets.
			sel := selector.GlobalSelector().Build()
			if r, err = NewResolver(ctx, options.discovery, tgt, options.block, insecure, sel); err != nil {
				return nil, fmt.Errorf("[http client] new resolver for %q failed: %w", options.endpoint, err)
			}
		} else if _, _, err := utils.ExtractHostPort(options.endpoint); err != nil {
			return nil, fmt.Errorf("[http client] invalid endpoint format: %v: %w", options.endpoint, err)
		}
	}
	return &Client{
		opts:     options,
		r:        r,
		target:   tgt,
		insecure: insecure,
		client: &http.Client{
			Timeout:   options.timeout,
			Transport: options.transport,
		},
	}, nil
}

// Invoke sends an HTTP request to path on the client's target and decodes the
// response into reply.
//
// Make sure contentType matches the configured codecs. When args is
// non-nil, it is encoded with the client's EncodeRequestFunc using
// contentType. A 2xx response body is decoded into reply with the client's
// DecodeResponseFunc. Any other status is returned as an error carrying the
// status line and the response body.
//
// The call passes through the client's middleware chain. Each time the final
// handler runs (for example, again under a retrying middleware), it sends a
// fresh clone of the request with the body rewound. Later attempts therefore
// never send an already-drained body.
func (c *Client) Invoke(ctx context.Context, method, contentType, path string, args interface{}, reply interface{}) error {
	var body io.Reader
	if args != nil {
		data, err := c.opts.encoder(ctx, contentType, args)
		if err != nil {
			return err
		}
		body = bytes.NewReader(data)
	}
	url := fmt.Sprintf("%s://%s%s", c.target.Scheme, c.target.Authority, path)
	req, err := http.NewRequest(method, url, body)
	if err != nil {
		return err
	}
	if contentType != "" {
		req.Header.Set("Content-Type", contentType)
	}
	if c.insecure {
		req.URL.Scheme = "http"
	} else {
		req.URL.Scheme = "https"
	}
	h := func(ctx context.Context, in interface{}) (interface{}, error) {
		// req serves as a template. Sending consumes its *bytes.Reader body,
		// so each attempt gets a clone whose body is re-created from
		// GetBody (which http.NewRequest sets for *bytes.Reader bodies).
		attempt := req.Clone(ctx)
		if req.GetBody != nil {
			rc, gErr := req.GetBody()
			if gErr != nil {
				return nil, gErr
			}
			attempt.Body = rc
		}
		res, reqErr := c.do(attempt)
		if reqErr != nil {
			return nil, reqErr
		}
		defer res.Body.Close()
		if res.StatusCode < 200 || res.StatusCode > 299 {
			detail := ""
			data, rErr := io.ReadAll(res.Body)
			if rErr == nil {
				detail = string(data)
			}
			errMsg := "status:" + res.Status + " detail:" + detail
			return nil, errors.New(errMsg)
		}
		if dErr := c.opts.decoder(ctx, res, reply); dErr != nil {
			return nil, dErr
		}
		return reply, nil
	}
	var p selector.Peer
	ctx = selector.NewPeerContext(ctx, &p)
	if len(c.opts.middleware) > 0 {
		h = middleware.Chain(c.opts.middleware...)(h)
	}
	_, err = h(ctx, args)
	return err
}

// do sends req with the underlying http.Client.
//
// When the client has a discovery resolver, do first asks the resolver's
// selector for a node, using req's context. It then rewrites the request URL
// host and the Host header to that node's address. The scheme follows the
// client's TLS setting. The removed debug print wrote every picked address to
// stdout.
func (c *Client) do(req *http.Request) (*http.Response, error) {
	if c.r != nil {
		node, err := c.r.Selector.Next(req.Context())
		if err != nil {
			return nil, err
		}
		if c.insecure {
			req.URL.Scheme = "http"
		} else {
			req.URL.Scheme = "https"
		}
		req.URL.Host = node.Address()
		req.Host = node.Address()
	}
	return c.client.Do(req)
}

// ContentSubtype returns the subtype of a media type, for example "json" for
// "application/json; charset=utf-8". The part between the first '/' and the
// first ';' (or the end of the string) is returned, with surrounding
// whitespace trimmed and lower-cased. RFC 7231 allows optional whitespace
// before ';' and makes media type tokens case-insensitive. The result is ""
// when contentType has no '/' or when a ';' appears before the first '/'.
func ContentSubtype(contentType string) string {
	left := strings.Index(contentType, "/")
	if left == -1 {
		return ""
	}
	right := strings.Index(contentType, ";")
	if right == -1 {
		right = len(contentType)
	}
	if right < left {
		return ""
	}
	return strings.ToLower(strings.TrimSpace(contentType[left+1 : right]))
}

// parseResponse picks the codec for decoding r's body from the subtype of its
// Content-Type header, falling back to the "json" codec when no codec is
// registered under that subtype.
func parseResponse(r *http.Response) encoding.Codec {
	name := ContentSubtype(r.Header.Get("Content-Type"))
	codec := encoding.GetCodec(name)
	if codec == nil {
		// No codec for this subtype: use the JSON codec.
		codec = encoding.GetCodec("json")
	}
	return codec
}

// encodeArg picks the codec for encoding a request body of the given content
// type, falling back to the "json" codec when no codec is registered under
// its subtype.
func encodeArg(contentType string) encoding.Codec {
	if codec := encoding.GetCodec(ContentSubtype(contentType)); codec != nil {
		return codec
	}
	// No codec for this subtype: use the JSON codec.
	return encoding.GetCodec("json")
}

// DefaultRequestEncoder is the default EncodeRequestFunc. It marshals in with
// the codec that encodeArg selects for contentType. The context is ignored.
func DefaultRequestEncoder(_ context.Context, contentType string, in interface{}) ([]byte, error) {
	codec := encodeArg(contentType)
	return codec.Marshal(in)
}

// DefaultResponseDecoder is the default DecodeResponseFunc. It reads the whole
// response body, unmarshals it into v with the codec that parseResponse
// selects from the response Content-Type, and closes the body before
// returning. The context is ignored.
func DefaultResponseDecoder(_ context.Context, res *http.Response, v interface{}) error {
	defer res.Body.Close()
	payload, readErr := io.ReadAll(res.Body)
	if readErr != nil {
		return readErr
	}
	codec := parseResponse(res)
	return codec.Unmarshal(payload, v)
}
